返回博客列表

Untitled

1. 导入必要的模块

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
import torch.nn.functional as F

2. 初始化 WandbLogger

  • 替代手动 wandb.init:通过 WandbLogger 配置所有参数。
  • 关键参数
  • project:WandB 项目名称。
  • name:实验名称(支持动态生成,如 f"{args.method}_{args.dataset}")。
  • group:实验分组(可选)。
  • save_code:是否保存代码快照(默认 True)。
  • config:记录超参数(如 args)。
wandb_logger = WandbLogger(
    project="ESN",
    name=f"{args.method}_{args.dataset}",  # 动态名称示例
    group=f"{args.dataset}_group",
    notes="实验备注",
    save_code=True,
    config=args  # 记录所有超参数
)

3. 配置 Trainer

  • 关键点:将 WandbLogger 实例传递给 logger 参数。
  • 其他配置:如训练设备、最大轮次、回调等。
trainer = Trainer(
    logger=wandb_logger,  # 关键行:启用 WandB 日志
    default_root_dir="./checkpoints/",
    accelerator="gpu",
    devices=1,
    max_epochs=100,
    callbacks=[
        ModelCheckpoint(...),  # 模型保存回调
        LearningRateMonitor(...),  # 学习率监控
    ]
)

4. 在 LightningModule 中记录损失

  • 使用 self.log:在 training_step 中记录损失值。
  • 避免名称冲突:使用自定义名称(如 train_loss_epoch)而非默认的 train_loss
  • 参数一致性:确保所有 self.log 调用对同一名称的参数一致。
class MyModel(LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # 记录训练损失(自定义名称)
        self.log(
            "train_loss_epoch",  # 唯一名称,避免冲突
            loss,
            on_step=False,  # 不记录每一步
            on_epoch=True,  # 记录每个 epoch 的平均值
            prog_bar=True,  # 在进度条显示
            logger=True      # 同步到 WandB
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

5. 运行训练并验证

  • 启动训练

python model = MyModel() trainer.fit(model, train_dataloader)

  • 检查 WandB

  • 登录 WandB 网页端,查看对应项目。

  • Metrics 面板查看 train_loss_epoch 曲线。
  • Overview 面板查看超参数和代码快照。

常见问题排查

  1. 指标未显示
    - 确认 logger=True 已设置。
    - 检查 WandbLogger 是否传递给 Trainer
  2. 名称冲突报错
    - 确保所有 self.log 调用中名称唯一且参数一致。
    - 避免使用 Lightning 的默认名称(如 train_loss)。
  3. WandB 无数据
    - 运行 wandb login 确保已登录。
    - 检查网络连接或 API 密钥有效性。

最终效果

  • 自动记录:每个 epoch 的 train_loss_epoch 自动同步到 WandB。
  • 交互式曲线:WandB 自动生成可缩放、可筛选的损失曲线。
  • 实验管理:所有超参数和代码版本均被记录,确保实验可复现。

评论